"""Generate a markdown summary of the results of a benchmarking run."""
import argparse
import pathlib
from collections import Counter
from functools import lru_cache
from typing import Generator, Sequence, cast

import datasets
import numpy as np
from huggingface_sb3 import EnvironmentName
from rliable import library as rly
from rliable import metrics

from imitation.data import rollout, types
from imitation.data.huggingface_utils import TrajectoryDatasetSequence
from imitation.util.sacred_file_parsing import (
    find_sacred_runs,
    group_runs_by_algo_and_env,
)


@lru_cache(maxsize=None)
def get_random_agent_score(env: str):
    """Return the mean monitor return of the random-agent rollouts for ``env``.

    The rollouts are read from the ``train`` split of the Hugging Face dataset
    ``HumanCompatibleAI/random-<env>``. The env name is converted with
    ``EnvironmentName``. Results are memoized per ``env`` so each dataset is
    loaded at most once per process.

    Args:
        env: Environment id, as used as a key in the grouped sacred runs.

    Returns:
        The ``monitor_return_mean`` entry of ``rollout.rollout_stats`` computed
        over the random-agent trajectories.
    """
    dataset_id = f"HumanCompatibleAI/random-{EnvironmentName(env)}"
    train_split = datasets.load_dataset(dataset_id)["train"]
    # The dataset wrapper is only statically typed as a generic sequence;
    # rollout_stats needs trajectories with rewards.
    trajectories = cast(
        Sequence[types.TrajectoryWithRew],
        TrajectoryDatasetSequence(train_split),
    )
    return rollout.rollout_stats(trajectories)["monitor_return_mean"]


def print_markdown_summary(path: pathlib.Path) -> Generator[str, None, None]:
    """Yield the lines of a markdown summary of the sacred runs under ``path``.

    The summary consists of an intro paragraph, then a run-status section. The
    status section is emitted only when not every run has status ``COMPLETED``.
    After that come one score table per algorithm and the aggregate normalized
    scores (mean and IQM with bootstrap confidence intervals from rliable).

    Args:
        path: Directory that is searched for sacred runs.

    Yields:
        Markdown lines, without trailing newlines.

    Raises:
        NotADirectoryError: If ``path`` does not exist or is not a directory.
    """
    if not path.exists():
        raise NotADirectoryError(f"Path {path} does not exist.")
    if not path.is_dir():
        raise NotADirectoryError(f"Path {path} is not a directory.")

    yield "# Benchmark Summary"
    yield ""
    yield (
        f"This is a summary of the sacred runs in `{path}` generated by "
        f"`sacred_output_to_markdown_summary.py`."
    )
    yield ""

    runs_by_algo_and_env = group_runs_by_algo_and_env(path)
    algos = sorted(runs_by_algo_and_env.keys())

    overall_counts = Counter(run["status"] for _, run in find_sacred_runs(path))
    statuses = sorted(overall_counts)
    # Skip the status section only when every single run completed; any other
    # status (or having no runs at all) is worth reporting.
    if statuses != ["COMPLETED"]:
        yield from _run_status_section(
            runs_by_algo_and_env,
            algos,
            statuses,
            overall_counts,
        )

    yield from _scores_intro()
    for algo in algos:
        yield from _algo_scores_section(algo, runs_by_algo_and_env[algo])


def _run_status_section(
    runs_by_algo_and_env: dict,
    algos: Sequence[str],
    statuses: Sequence[str],
    overall_counts: Counter,
) -> Generator[str, None, None]:
    """Yield the overall and the per-(algorithm, environment) run status tables."""
    yield "## Run status"
    yield ""
    yield "Status | Count"
    yield "--- | ---"
    for status in statuses:
        yield f"{status} | {overall_counts[status]}"
    yield ""

    yield "## Detailed Run Status"
    yield ""
    yield f"Algorithm | Environment | {' | '.join(statuses)}"
    yield "--- | --- " + " | --- " * len(statuses)
    for algo in algos:
        runs_by_env = runs_by_algo_and_env[algo]
        for env in sorted(runs_by_env.keys()):
            env_counts = Counter(run["status"] for run in runs_by_env[env])
            # Counter yields 0 for statuses that never occur for this env.
            cells = " | ".join(str(env_counts[status]) for status in statuses)
            yield f"{algo} | {env} | {cells}"
    yield ""


def _scores_intro() -> Generator[str, None, None]:
    """Yield the heading and explanation of the score normalization."""
    yield "## Scores"
    yield ""
    yield (
        "The scores are normalized based on the performance of a random agent as the"
        " baseline and the expert as the maximum possible score as explained "
        "[in this blog post](https://araffin.github.io/post/rliable/):"
    )
    yield "> `(score - random_score) / (expert_score - random_score)`"
    yield ""
    yield (
        "Aggregate scores and confidence intervals are computed using the "
        "[rliable library](https://agarwl.github.io/rliable/)."
    )
    yield ""


def _algo_scores_section(
    algo: str,
    runs_by_env: dict,
) -> Generator[str, None, None]:
    """Yield the per-environment score table and aggregates for one algorithm."""
    yield f"### {algo.upper()}"
    yield ""
    yield "Environment | Score (mean/std)| Normalized Score (mean/std) | N"
    yield " --- | --- | --- | --- "
    normalized_scores_per_env = []
    for env in sorted(runs_by_env.keys()):
        # Sacred stores no result for runs that did not complete (e.g. FAILED or
        # INTERRUPTED), so only COMPLETED runs can be scored.
        completed = [run for run in runs_by_env[env] if run["status"] == "COMPLETED"]
        if not completed:
            yield f"{env} | - | - | 0"
            continue

        scores = [run["result"]["imit_stats"]["monitor_return_mean"] for run in completed]
        expert_scores = [
            run["result"]["expert_stats"]["monitor_return_mean"] for run in completed
        ]
        random_score = get_random_agent_score(env)
        normalized_score = [
            (score - random_score) / (expert_score - random_score)
            for score, expert_score in zip(scores, expert_scores)
        ]
        normalized_scores_per_env.append(normalized_score)

        yield (
            f"{env} | "
            f"{np.mean(scores):.3f} / {np.std(scores):.3f} | "
            f"{np.mean(normalized_score):.3f} / {np.std(normalized_score):.3f} | "
            f"{len(scores)}"
        )

    yield ""
    yield from _aggregate_section(normalized_scores_per_env)


def _aggregate_section(normalized_scores_per_env: list) -> Generator[str, None, None]:
    """Yield the aggregate (mean / IQM) normalized score table with 95% CIs."""
    yield "#### Aggregate Normalized scores"
    yield ""

    if not normalized_scores_per_env:
        yield "No completed runs are available to aggregate."
        yield ""
        return
    if len({len(env_scores) for env_scores in normalized_scores_per_env}) != 1:
        # rliable needs a rectangular (num_runs, num_tasks) score matrix.
        yield (
            "Aggregate scores are omitted because the environments do not all "
            "have the same number of completed runs."
        )
        yield ""
        return

    # Rows of the input are environments; transpose to (num_runs, num_envs).
    score_matrix = np.asarray(normalized_scores_per_env).T
    point_estimates, interval_estimates = rly.get_interval_estimates(
        {"normalized_score": score_matrix},
        lambda x: np.array([metrics.aggregate_mean(x), metrics.aggregate_iqm(x)]),
        reps=1000,
    )
    points = point_estimates["normalized_score"]
    # Interval estimates have shape (2, num_metrics): row 0 holds the lower
    # bounds and row 1 the upper bounds, one column per metric.
    lower_bounds, upper_bounds = interval_estimates["normalized_score"]

    yield "Metric | Value | 95% CI"
    yield " --- | --- | --- "
    for idx, metric_name in enumerate(("Mean", "IQM")):
        yield (
            f"{metric_name} | "
            f"{points[idx]:.3f} | "
            f"[{lower_bounds[idx]:.3f}, "
            f"{upper_bounds[idx]:.3f}]"
        )
    yield ""


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate a markdown summary of the results of a benchmarking run.",
    )
    parser.add_argument("path", type=pathlib.Path)
    parser.add_argument("--output", type=pathlib.Path, default="summary.md")

    args = parser.parse_args()

    # Validate the input before opening the output: opening with "w" truncates
    # the file, so a bad path would otherwise clobber an existing summary.
    if not args.path.is_dir():
        parser.error(f"path {args.path} is not an existing directory")

    with open(args.output, "w", encoding="utf-8") as fh:
        for line in print_markdown_summary(args.path):
            fh.write(f"{line}\n")
            # Flush per line so partial output is visible while slow steps
            # (e.g. loading the random-agent datasets) are still running.
            fh.flush()
